/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common_reg.h"
#include "execute_graph_builder.h"
#include "ge/infer_shape_functions.h"
#include "ge/infer_shape_mat_mul.h"
#include "op_def.h"
namespace ge {
// No-op tiling callback: ignores both the operator and the opaque tiling argument and
// unconditionally reports success.
Status TestTilingFunc(const Operator & /* op */, void * /* unused */) { return SUCCESS; }

OpDefRegistry RegOpDefs(ExecuteGraphBuilder &gb) {
  OpDefRegistry reg;
  reg.data = gb.AddOpDef("Data");
  gb.GetOpDefBuilder(reg.data)
      ->SetInputDefCount(1)
      .SetInputDef(0, {kMustIo, "x"})
      .SetOutputDefCount(1)
      .SetOutputDef(0, {kMustIo, "y"})
      .SetAttrDefCount(1)
      .SetAttrDef(0, {kOptionalIo, "index", []() { return AnyValue(0); }})
      .SetInferShapeFunc(ops::DataInferShape);

  reg.const_offset = gb.AddOpDef("Const");
  gb.GetOpDefBuilder(reg.const_offset)
      ->SetOutputDefCount(1)
      .SetOutputDef(0, {kMustIo, "y"})
      .SetAttrDefCount(1)
      .SetAttrDef(0, {kOptionalIo, "value", []() { return AnyValue(0); }})  // todo 正式切入主线后，这里要造一个空Tensor
      .SetInferShapeFunc(ops::DoNothingInferShape);

  reg.matmul = gb.AddOpDef("MatMul");
  gb.GetOpDefBuilder(reg.matmul)
      ->SetInputDefCount(3)
      .SetInputDef(0, {kMustIo, "x1"})
      .SetInputDef(1, {kMustIo, "x2"})
      .SetInputDef(2, {kOptionalIo, "bias"})
      .SetOutputDefCount(1)
      .SetOutputDef(0, {kMustIo, "y"})
      .SetAttrDefCount(2)
      .SetAttrDef(0, {kOptionalIo, "transpose_x1", []() { return AnyValue(false); }})
      .SetAttrDef(1, {kOptionalIo, "transpose_x2", []() { return AnyValue(false); }})
      .SetInferShapeFunc(ops::MatMulInferShape)
      .SetTilingFunc(TestTilingFunc);

  reg.conv2d = gb.AddOpDef("Conv2D");
  gb.GetOpDefBuilder(reg.conv2d)
      ->SetInputDefCount(4)
      .SetInputDef(0, {kMustIo, "x"})
      .SetInputDef(1, {kMustIo, "filter"})
      .SetInputDef(2, {kOptionalIo, "bias"})
      .SetInputDef(3, {kOptionalIo, "offset_w"})
      .SetOutputDefCount(1)
      .SetOutputDef(0, {kMustIo, "y"})
      .SetAttrDefCount(6)
      .SetAttrDef(0, {kMustIo, "strides", nullptr})
      .SetAttrDef(1, {kMustIo, "pads", nullptr})
      .SetAttrDef(2,
                  {kOptionalIo, "dilations",
                   []() {
                     return AnyValue(std::vector<int64_t>({1, 1, 1, 1}));
                   }})
      .SetAttrDef(3, {kOptionalIo, "groups", []() { return AnyValue(1); }})
      .SetAttrDef(4, {kOptionalIo, "data_format", []() { return AnyValue(std::string("NHWC")); }})
      .SetAttrDef(5, {kOptionalIo, "offset_x", []() { return AnyValue(0); }});

  reg.netoutput = gb.AddOpDef("NetOutput");
  gb.GetOpDefBuilder(reg.netoutput)
      ->SetInputDefCount(1)
      .SetInputDef(0, {kDynamicIo, "x"})
      .SetOutputDefCount(1)
      .SetOutputDef(0, {kDynamicIo, "y"})
      .SetInferShapeFunc(ops::DoNothingInferShape);

  reg.transdata = gb.AddOpDef("TransData");
  gb.GetOpDefBuilder(reg.transdata)
      ->SetInputDefCount(1)
      .SetInputDef(0, {kMustIo, "src"})
      .SetOutputDefCount(1)
      .SetOutputDef(0, {kMustIo, "dst"})
      .SetAttrDefCount(2)
      .SetAttrDef(0, {kMustIo, "src_format", nullptr})
      .SetAttrDef(1, {kMustIo, "dst_format", nullptr})
      .SetInferShapeFunc(ops::TransDataInferShape);

  reg.split = gb.AddOpDef("Split");
  gb.GetOpDefBuilder(reg.split)
      ->SetInputDefCount(2)
      .SetInputDef(0, {kMustIo, "split_dim"})
      .SetInputDef(1, {kMustIo, "x"})
      .SetOutputDefCount(1)
      .SetOutputDef(0, {kDynamicIo, "y"})
      .SetAttrDefCount(1)
      .SetAttrDef(0, {kMustIo, "num_split", nullptr});
  return reg;
}
}  // namespace ge